{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import gym\n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_frames(frames):\n",
    "    for i, frame in enumerate(frames):\n",
    "        clear_output(wait=True)\n",
    "        print(frame['frame'])\n",
    "        print(f\"Episode: {frame['episode']}\")\n",
    "        print(f\"Timestep: {i + 1}\")\n",
    "        print(f\"State: {frame['state']}\")\n",
    "        print(f\"Action: {frame['action']}\")\n",
    "        print(f\"Reward: {frame['reward']}\")\n",
    "        time.sleep(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('Taxi-v3').env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用 Q-Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_table = np.zeros([env.observation_space.n, env.action_space.n])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_q_table(state, action, reward, next_state, alpha, gamma):\n",
    "    old_value = q_table[state, action]\n",
    "    next_max = np.max(q_table[next_state])\n",
    "    new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)\n",
    "    q_table[state, action] = new_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def epsilon_greedy_policy(state, epsilon):\n",
    "    if random.uniform(0,1) < epsilon:\n",
    "        return env.action_space.sample()\n",
    "    else:\n",
    "        return np.argmax(q_table[state])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.1\n",
    "gamma = 0.6\n",
    "epsilon = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode: 50000\n"
     ]
    }
   ],
   "source": [
    "for i in range(1, 50000 + 1):\n",
    "    state = env.reset()\n",
    "    r = 0\n",
    "    epochs = 0\n",
    "    done = False\n",
    "    \n",
    "    while not done:       \n",
    "        # In each state, we select the action by epsilon-greedy policy\n",
    "        action = epsilon_greedy_policy(state, epsilon)\n",
    "        \n",
    "        # then we perform the action and move to the next state, and receive the reward\n",
    "        next_state, reward, done, _ = env.step(action)\n",
    "        \n",
    "        # Next we update the Q value using our update_q_table function\n",
    "        # which updates the Q value by Q learning update rule\n",
    "        update_q_table(state, action, reward, next_state, alpha, gamma)\n",
    "        \n",
    "        # Finally we update the previous state as next state\n",
    "        state = next_state\n",
    "\n",
    "        # Store all the rewards obtained\n",
    "        r += reward\n",
    "\n",
    "    if i % 100 == 0:\n",
    "        clear_output(wait=True)\n",
    "        print(f\"Episode: {i}\")\n",
    "\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2.37830323, -2.27325184, -2.40093809, -2.34339469, -9.76811481,\n",
       "       -8.60512677])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "q_table[328]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试训练好的智能体的性能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results after 1000 episodes:\n",
      "Average timesteps per episode: 12.996\n",
      "Average penalties per episode: 0.0\n"
     ]
    }
   ],
   "source": [
    "total_epochs, total_penalties = 0, 0\n",
    "episodes = 1000\n",
    "frames = []\n",
    "\n",
    "for ep in range(episodes):\n",
    "    state = env.reset()\n",
    "    epochs, penalties, reward = 0, 0, 0\n",
    "    \n",
    "    done = False\n",
    "    \n",
    "    while not done:\n",
    "        action = np.argmax(q_table[state])\n",
    "        state, reward, done, _ = env.step(action)\n",
    "        \n",
    "        if reward == -10:\n",
    "            penalities += 1\n",
    "            \n",
    "        # Put each rendered frame into dict for animation\n",
    "        frames.append({\n",
    "            'frame': env.render(mode='ansi'),\n",
    "            'episode': ep, \n",
    "            'state': state,\n",
    "            'action': action,\n",
    "            'reward': reward\n",
    "            }\n",
    "        )\n",
    "        epochs += 1\n",
    "    total_penalties += penalties\n",
    "    total_epochs += epochs\n",
    "\n",
    "print(f\"Results after {episodes} episodes:\")\n",
    "print(f\"Average timesteps per episode: {total_epochs / episodes}\")\n",
    "print(f\"Average penalties per episode: {total_penalties / episodes}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+\n",
      "|R: | : :G|\n",
      "| : | : : |\n",
      "| : : :\u001b[42m_\u001b[0m: |\n",
      "| | : | : |\n",
      "|Y| : |\u001b[35mB\u001b[0m: |\n",
      "+---------+\n",
      "  (South)\n",
      "\n",
      "Episode: 0\n",
      "Timestep: 10\n",
      "State: 279\n",
      "Action: 0\n",
      "Reward: -1\n"
     ]
    }
   ],
   "source": [
    "print_frames(frames[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
